{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "47d5313e-c29d-4581-a9c7-a45122337069",
      "metadata": {
        "id": "47d5313e-c29d-4581-a9c7-a45122337069"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"300\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "# The Forward-Forward Algorithm with a Spiking Neural Network\n",
        "### Tutorial written by Ethan Mulle and Abhinandan Singh\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "oll2NNFeG1NG",
      "metadata": {
        "id": "oll2NNFeG1NG"
      },
      "source": [
        "The following tutorial introduces how to implement the Forward-Forward algorithm proposed by Geoffrey Hinton into a spiking neural network (SNN).\n",
        "> <cite> [Geoffrey Hinton. The Forward-Forward Algorithm: Some Preliminary\n",
        "Investigations. 2022. arXiv: 2212.13345 [cs.LG].](https://arxiv.org/abs/2212.13345) </cite>\n",
        "\n",
        "The Forward-Forward algorithm is inspired from a model of learning in the cortex, and offers a potentially low-power, local, backprop-free alternative to training a model. Of course, it is a challenging algorithm to scale up to more challenging tasks. But this is a good starting point to learn from.\n",
        "\n",
        "For a comprehensive overview on how SNNs work, and what is going on under the hood, [then you might be interested in the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)\n",
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ELOnnEYi4YY8",
      "metadata": {
        "id": "ELOnnEYi4YY8"
      },
      "source": [
        "Install the latest PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "hDnIEHOKB8LD",
      "metadata": {
        "id": "hDnIEHOKB8LD"
      },
      "outputs": [],
      "source": [
        "!pip install snntorch --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "WL487gZW1Agy",
      "metadata": {
        "id": "WL487gZW1Agy"
      },
      "outputs": [],
      "source": [
        "import torch, torch.nn as nn\n",
        "import snntorch as snn\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "EYf13Gtx1OCj",
      "metadata": {
        "id": "EYf13Gtx1OCj"
      },
      "source": [
        "## 1. The MNIST Dataset\n",
        "\n",
        "### 1.1 Dataloading\n",
        "\n",
        "- Set up dataloaders for the MNIST training and test datasets.\n",
        "\n",
        "- Define a data transformation pipeline, and iterate over the first batch of the training loader to inspect the size of the data.\n",
        "\n",
        "- The transformations include converting images to grayscale, transforming them to tensors, normalizing pixel values, and flattening the tensors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3GdglZjK04cb",
      "metadata": {
        "id": "3GdglZjK04cb"
      },
      "outputs": [],
      "source": [
        "from torchvision.datasets import MNIST\n",
        "from torchvision.transforms import Compose, Grayscale, ToTensor, Normalize, Lambda\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "\n",
        "transform = Compose([\n",
        "    Grayscale(),\n",
        "    ToTensor(),\n",
        "    Normalize((0,), (1,)),\n",
        "    Lambda(lambda x: torch.flatten(x))])\n",
        "\n",
        "# Train set\n",
        "mnist_train = MNIST('./data/', train=True, download=True, transform=transform)\n",
        "train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)\n",
        "\n",
        "# Test set\n",
        "mnist_test = MNIST('./data/', train=False, download=True, transform=transform)\n",
        "test_loader = DataLoader(mnist_train, batch_size=128, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "XOT9ESAC5d1G",
      "metadata": {
        "id": "XOT9ESAC5d1G"
      },
      "source": [
        "## 2. Forward-Forward Algorithm Theory\n",
        "The Forward-Forward algorithm is a learning rule inspired by Boltzmann machines and Noise Contrastive Estimation. The core idea is that each layer learns to tell apart *positive* examples (which we treat as '*good*') from *negative* examples (which we treat as '*bad*'. It works by replacing the forward and backward passes of backpropagation with two forward passes:\n",
        "- The positive pass\n",
        "- The negative pass.\n",
        "\n",
        "In the positive pass, a training example plus its label is fed through the network. Every layer computes a *goodness* measure (e.g., sum of squared neuron outputs). The weights are updated to *increase* goodness for that class's *correct* examples.  \n",
        "\n",
        "In teh negative pass, we feed in *bad* examples. For example, the same input might be applied with the *wrong* class label. The network again measures the goodness in each layer, but now the weights are updated to decrease goodness. It is thought that over many iterations, each layer refines local filters or patterns that best distinguish each class from noise or other classes, without needing a global backpropagation signal.\n",
        "\n",
        "A high-level illustration of the algorithm is shown below.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial_forward-forward/forward-forward.png?raw=true' width=\"800\">\n",
        "</center>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "B1B_tzLsihYM",
      "metadata": {
        "id": "B1B_tzLsihYM"
      },
      "source": [
        "Hinton proposes that goodness can be measured through a metric, such as the sum of the square of neural activities. The goal is to make the goodness to exceed some threshold value for real data and to fall below the thresholdfor negative data. In practice, Hinton applies the sigmoid function $\\sigma$ to this deviation from the threshold $\\theta$ such that the result can be interpreted as the probability that an input vector is positive (i.e., real):\n",
        "\n",
        "\\begin{equation}\n",
        "    p(positive) = \\sigma \\left(\\sum_j y_j^2 - \\theta \\right)\n",
        "\\end{equation}\n",
        "\n",
        "where $y_j$ is the activity of hidden unit $j$."
      ]
    },
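    {
      "cell_type": "markdown",
      "id": "ff-sketch-goodness",
      "metadata": {
        "id": "ff-sketch-goodness"
      },
      "source": [
        "As a rough illustration of the equation above, the following sketch computes the goodness and the resulting probability for a toy activity vector. It is written in plain Python so the arithmetic is easy to follow. The names `goodness` and `p_positive`, the example activities, and the threshold value are illustrative assumptions and are not part of the tutorial's implementation.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def goodness(activities):\n",
        "    # Sum of squared activities, as in the equation above\n",
        "    return sum(y * y for y in activities)\n",
        "\n",
        "def p_positive(activities, theta):\n",
        "    # Sigmoid of the deviation between goodness and the threshold theta\n",
        "    return 1.0 / (1.0 + math.exp(-(goodness(activities) - theta)))\n",
        "\n",
        "theta = 2.0                # hypothetical threshold\n",
        "strong = [1.0, 1.5, 0.5]   # goodness = 1 + 2.25 + 0.25 = 3.5\n",
        "weak = [0.2, 0.1, 0.3]     # goodness = 0.04 + 0.01 + 0.09 = 0.14\n",
        "\n",
        "print(p_positive(strong, theta))  # above 0.5, since 3.5 > theta\n",
        "print(p_positive(weak, theta))    # below 0.5, since 0.14 < theta\n",
        "```\n",
        "\n",
        "An activity vector whose goodness equals the threshold gives a probability of exactly 0.5, which is the decision boundary between positive and negative data."
      ]
    },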
    {
      "cell_type": "markdown",
      "id": "GB77EmhgfaN0",
      "metadata": {
        "id": "GB77EmhgfaN0"
      },
      "source": [
        "## 3. Implementation\n",
        "We follow and modify the Forward-Forward algorithm implementation found here:\n",
        "> <cite> [Mohammad Pezeshki. pytorch_forward_forward. Github Repo January 2023.](https://github.com/mpezeshki/pytorch_forward_forward) </cite>\n",
        "\n",
        "To do so, we must define both a layer and a network class."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4D_1BCXKkRbe",
      "metadata": {
        "id": "4D_1BCXKkRbe"
      },
      "source": [
        "### 3.1 Defining the Layer Class\n",
        "- **LeakyLayer**: is a custom leaky integrate-and-fire layer with additional attributes. The class is designed to be used as a layer in a neural network.\n",
        "\n",
        "- The **forward method**: defines the computations for a forward pass through the 'LeakyLayer'. The membrane potential is initialized, the input is normalized and weighted, and then passed through a leaky integrate-and-fire neuron, in order to return the resulting membrane potential.\n",
        "\n",
        "- The **train method**: performs a training loop over a specified number of epochs. It computes the goodness for positive and negative data, calculates the loss based on these goodness values, locally backpropagates the gradients, and updates the parameters using the optimizer. The method returns the membrane potentials for positive and negative data after training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "EqjAjimIkZLu",
      "metadata": {
        "id": "EqjAjimIkZLu"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "from torch.optim import Adam\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class LeakyLayer(nn.Linear):\n",
        "\n",
        "    def __init__(self, in_features, out_features, activation, bias=False):\n",
        "        super().__init__(in_features, out_features, bias=bias)\n",
        "\n",
        "        # Enable the choice between a spiking neuron and a standard artificial neuron (ReLU activation)\n",
        "        if activation == \"lif\":\n",
        "          self.activation = snn.Leaky(beta=0.8)\n",
        "          self.lif = True\n",
        "        else:\n",
        "          self.activation = nn.ReLU()\n",
        "          self.lif = False\n",
        "\n",
        "        self.opt = Adam(self.parameters(), lr=0.03)\n",
        "        self.threshold = 5.0\n",
        "        self.num_epochs = 1000\n",
        "\n",
        "    def forward(self, x):\n",
        "        if self.lif == True:\n",
        "          mem = self.activation.init_leaky() # initialize membrane potential\n",
        "\n",
        "        x_direction = x / (torch.norm(x, p=2, dim=1, keepdim=True) + 1e-4) # normalize the input\n",
        "\n",
        "        # Linear layer - not using nn.Linear(...) so we can define the update rule\n",
        "        weighted_input = torch.mm(x_direction, self.weight.T.to(device)) # MatMul between normalized input and weight matrix\n",
        "\n",
        "        if self.lif == True:\n",
        "          spk, potential = self.activation(weighted_input, mem) # note: only one step in time. Wrap in a for-loop to iterate for longer.\n",
        "        else:\n",
        "          potential = self.activation(weighted_input)\n",
        "\n",
        "        return potential  # to be used in subsequent layers or as the final output of the last layer\n",
        "\n",
        "    def train(self, x_pos, x_neg):\n",
        "\n",
        "        tot_loss = []  # store the loss values for each layer in each epoch\n",
        "        for _ in tqdm(range(self.num_epochs), desc=\"Training LeakyLayer\"):\n",
        "\n",
        "            # Compute goodness\n",
        "            g_pos = self.forward(x_pos).pow(2).mean(1) # positive data\n",
        "            g_neg = self.forward(x_neg).pow(2).mean(1) # negative data\n",
        "\n",
        "            # take the mean of differences between goodness and threshold across pos and neg samples\n",
        "            loss = F.softplus(torch.cat([-g_pos + self.threshold, g_neg - self.threshold])).mean()\n",
        "\n",
        "            self.opt.zero_grad()\n",
        "            loss.backward() # local backward-pass\n",
        "            self.opt.step() # update weights\n",
        "            tot_loss.append(loss)\n",
        "\n",
        "        # returns the final membrane potentials (activations) for positive and negative examples after training.\n",
        "        # detach() ensures no further backward pass is possible\n",
        "        output = self.forward(x_pos).detach(), self.forward(x_neg).detach()\n",
        "        return (output, tot_loss)\n"
      ]
    },
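    {
      "cell_type": "markdown",
      "id": "ff-sketch-softplus",
      "metadata": {
        "id": "ff-sketch-softplus"
      },
      "source": [
        "The loss in `LeakyLayer.train` can be related to the probability defined in Section 2. The short derivation below assumes the loss is the negative log-likelihood of that probability, and writes $g$ for the goodness of one sample and $\\theta$ for the threshold. The symbols $L_{\\text{pos}}$ and $L_{\\text{neg}}$ are introduced here for illustration only.\n",
        "\n",
        "\\begin{align}\n",
        "    L_{\\text{pos}} &= -\\log \\sigma(g - \\theta) = \\log\\left(1 + e^{\\theta - g}\\right) = \\mathrm{softplus}(\\theta - g) \\\\\n",
        "    L_{\\text{neg}} &= -\\log\\left(1 - \\sigma(g - \\theta)\\right) = -\\log \\sigma(\\theta - g) = \\log\\left(1 + e^{g - \\theta}\\right) = \\mathrm{softplus}(g - \\theta)\n",
        "\\end{align}\n",
        "\n",
        "These are the two terms concatenated inside `F.softplus(...)` in the code above. Note that the code computes goodness as the mean, rather than the sum, of squared activities, which rescales $g$ relative to $\\theta$ by the number of neurons in the layer."
      ]
    },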
    {
      "cell_type": "markdown",
      "id": "Shh0lapBkZxO",
      "metadata": {
        "id": "Shh0lapBkZxO"
      },
      "source": [
        "### 3.2 Defining the Network\n",
        "\n",
        "- The **predict** method overlays label information onto the input, passes it through each layer of the neural network, computes goodness values at each layer, accumulates them for each label, and returns the predicted labels based on the maximum goodness.\n",
        "\n",
        "- The **train** method iterates over the layers of the neural network, prints information about the current layer being trained, and updates the input tensors ('h_pos' and 'h_neg') by calling the 'train' method of each layer. This allows the layers to learn and adapt their parameters during the training process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5qkfsA69aEcQ",
      "metadata": {
        "id": "5qkfsA69aEcQ"
      },
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n",
        "\n",
        "    def __init__(self, dims, activation):\n",
        "        super().__init__()\n",
        "\n",
        "        self.layers = nn.ModuleList([\n",
        "            LeakyLayer(dims[d], dims[d + 1], activation) for d in range(len(dims) - 1)  # define a multi-layer network\n",
        "        ])\n",
        "\n",
        "\n",
        "    def predict(self, x):\n",
        "        goodness_per_label = [] # used to store the goodness values for each label\n",
        "\n",
        "        for label in range(10): # loop over the 10 possible labels\n",
        "            h = overlay_y_on_x(x, label)  # overlay label information on top of the input data. function defined in next code-block.\n",
        "            goodness = []\n",
        "\n",
        "            for layer in self.layers: # inner loop over the layers\n",
        "                h = layer(h)\n",
        "                goodness.append(h.pow(2).mean(1)) # compute and store goodness for every layer\n",
        "\n",
        "            goodness_per_label.append(sum(goodness).unsqueeze(1)) # sum goodness values between all layers for one sample of data\n",
        "\n",
        "        goodness_per_label = torch.cat(goodness_per_label, 1)  # convert list to tensor\n",
        "\n",
        "        return goodness_per_label.argmax(1)  # returns maximum value giving the predicted label for each input\n",
        "\n",
        "    def train(self, x_pos, x_neg):\n",
        "        h_pos, h_neg = x_pos, x_neg\n",
        "        layer_losses = []\n",
        "        for i, layer in enumerate(self.layers):\n",
        "            print('Training layer', i, '...')\n",
        "            outputs, loss = layer.train(h_pos, h_neg)  # update the weight matrix for each layer based on the local loss (defined by \"goodness\")\n",
        "            h_pos, h_neg = outputs\n",
        "            layer_losses.append(loss)\n",
        "        return torch.Tensor(layer_losses)"
      ]
    },
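    {
      "cell_type": "markdown",
      "id": "ff-sketch-predict",
      "metadata": {
        "id": "ff-sketch-predict"
      },
      "source": [
        "The label-selection step in `predict` can be sketched without any tensors. The example below, a minimal sketch with made-up goodness values, sums the per-layer goodness for each candidate label and picks the label with the largest total. The function name `pick_label` and the numbers are hypothetical.\n",
        "\n",
        "```python\n",
        "def pick_label(goodness_by_label):\n",
        "    # goodness_by_label[label] holds one goodness value per layer\n",
        "    totals = [sum(per_layer) for per_layer in goodness_by_label]\n",
        "    # Index of the largest total goodness, i.e. the predicted label\n",
        "    return max(range(len(totals)), key=lambda label: totals[label])\n",
        "\n",
        "# Three candidate labels, two layers each (values are made up)\n",
        "toy = [\n",
        "    [0.4, 0.3],  # label 0: total 0.7\n",
        "    [2.1, 1.8],  # label 1: total 3.9\n",
        "    [0.9, 0.2],  # label 2: total 1.1\n",
        "]\n",
        "print(pick_label(toy))  # 1\n",
        "```\n",
        "\n",
        "In `Net.predict`, the same selection is done for a whole batch at once with `argmax(1)` over a tensor with one column per label."
      ]
    },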
    {
      "cell_type": "markdown",
      "id": "lY2WozWrkwBp",
      "metadata": {
        "id": "lY2WozWrkwBp"
      },
      "source": [
        "## 4. Training the Network"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6741ee0b",
      "metadata": {
        "id": "6741ee0b"
      },
      "source": [
        "### 4.1 Generate data sampels\n",
        "\n",
        "Before training, we need to get our datasets in a format that is required by the Forward-Forward algorithm. Recall that the labels and input data are somehow going to be combined.\n",
        "\n",
        "- The **'overlay_y_on_x'** function:\n",
        "    - **Input: 'x'**: input data tensor\n",
        "    - **Input: 'y'**: input label\n",
        "    - **Output: 'x_'**: The first 10 pixels of *x* are zero'd out and replaced with a one-hot-encoded representation of the label *y*. I.e., the label is 'printed' in pixel form on the input. The y-label can also be randomized to generate wrong labels, or 'negative' samples.\n",
        "\n",
        "This function is applied to the MNIST dataset and passed to your network. Each layer is updated based on the goodness function above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5c8569be",
      "metadata": {
        "id": "5c8569be"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def overlay_y_on_x(x, y):\n",
        "    \"\"\"Replace the first 10 pixels of data [x] with one-hot-encoded label [y]\"\"\"\n",
        "    x_ = x.clone()\n",
        "    x_[:, :10] *= 0.0  # zero out the first 10 pixels of input data\n",
        "    x_[range(x.shape[0]), y] = x.max()  # the y-label 'activates' the corresponding pixel\n",
        "    return x_\n",
        "\n",
        "def visualize_sample(data, name='', idx=0):\n",
        "    reshaped = data[idx].cpu().reshape(28, 28)\n",
        "    plt.figure(figsize = (4, 4))\n",
        "    plt.title(name)\n",
        "    plt.imshow(reshaped, cmap=\"gray\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "68d57642",
      "metadata": {
        "id": "68d57642"
      },
      "source": [
        "### 4.2 Visualize Data Samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ym6ZuGxRkyji",
      "metadata": {
        "id": "ym6ZuGxRkyji"
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(123)\n",
        "net = Net([784, 500, 500], \"lif\").to(device) # construct neural net with 784 input neurons, 500 neurons in each of the two hidden layers\n",
        "x, y = next(iter(train_loader))\n",
        "x, y = x.to(device), y.to(device)\n",
        "\n",
        "x_pos = overlay_y_on_x(x, y)  # generate positive samples (i.e., correct labels overlayed on input data)\n",
        "\n",
        "# generate negative samples (i.e., incorrect labels overlayed on input data)\n",
        "rnd = torch.randperm(x.size(0))\n",
        "x_neg = overlay_y_on_x(x, y[rnd])  #\n",
        "\n",
        "# visualize samples\n",
        "for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):\n",
        "    visualize_sample(data, name)"
      ]
    },
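    {
      "cell_type": "markdown",
      "id": "ff-sketch-wrong-labels",
      "metadata": {
        "id": "ff-sketch-wrong-labels"
      },
      "source": [
        "Permuting the labels within a batch, as in the cell above, can occasionally leave a sample with its correct label. One alternative, sketched here in plain Python under the assumption of 10 classes, shifts each label by a random non-zero offset so that the result always differs from the original. The helper name `make_wrong_labels` is hypothetical and is not used elsewhere in this tutorial.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def make_wrong_labels(labels, num_classes=10, rng=random):\n",
        "    # An offset in [1, num_classes - 1], taken modulo num_classes,\n",
        "    # can never map a label back onto itself\n",
        "    return [(y + rng.randint(1, num_classes - 1)) % num_classes for y in labels]\n",
        "\n",
        "random.seed(0)\n",
        "labels = [3, 7, 0, 9, 3]\n",
        "wrong = make_wrong_labels(labels)\n",
        "print(all(w != y for w, y in zip(wrong, labels)))  # True\n",
        "```\n",
        "\n",
        "The same idea could be applied elementwise to the label tensor `y` before calling `overlay_y_on_x`."
      ]
    },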
    {
      "cell_type": "markdown",
      "id": "568475ac",
      "metadata": {
        "id": "568475ac"
      },
      "source": [
        "### 4.3 Training the Network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "63be0e05",
      "metadata": {
        "id": "63be0e05"
      },
      "outputs": [],
      "source": [
        "loss = net.train(x_pos, x_neg)\n",
        "print('Train error:', 100*(1.0 - net.predict(x).eq(y).float().mean().item()),'%')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "LTJonZ4elIXf",
      "metadata": {
        "id": "LTJonZ4elIXf"
      },
      "source": [
        "## 5. Testing the Network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "74rCTBPRlK1u",
      "metadata": {
        "id": "74rCTBPRlK1u"
      },
      "outputs": [],
      "source": [
        "x_te, y_te = next(iter(test_loader))\n",
        "x_te, y_te = x_te.to(device), y_te.to(device)\n",
        "\n",
        "print('Test error:', 100*(1.0 - net.predict(x_te).eq(y_te).float().mean().item()), '%')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cd47411c",
      "metadata": {
        "id": "cd47411c"
      },
      "source": [
        "It's not great, but running more time-steps of simulations across more epochs should help. Give it a go."
      ]
    },
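    {
      "cell_type": "markdown",
      "id": "ff-sketch-time-steps",
      "metadata": {
        "id": "ff-sketch-time-steps"
      },
      "source": [
        "To make 'more time steps' concrete, here is a simplified leaky integrate-and-fire update iterated over several steps in plain Python. It is a sketch under stated assumptions: a constant input, a decay `beta` of 0.8 as in `LeakyLayer`, a firing threshold of 1.0, and reset by subtraction. It is not snnTorch's exact implementation, and `simulate_lif` is a hypothetical helper.\n",
        "\n",
        "```python\n",
        "def simulate_lif(input_current, num_steps, beta=0.8, threshold=1.0):\n",
        "    mem = 0.0\n",
        "    spikes, trace = [], []\n",
        "    for _ in range(num_steps):\n",
        "        mem = beta * mem + input_current   # leaky integration\n",
        "        spk = 1 if mem > threshold else 0  # fire when the threshold is crossed\n",
        "        mem -= spk * threshold             # reset by subtraction after a spike\n",
        "        spikes.append(spk)\n",
        "        trace.append(mem)\n",
        "    return spikes, trace\n",
        "\n",
        "spikes, trace = simulate_lif(input_current=0.3, num_steps=10)\n",
        "print(spikes)  # [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]\n",
        "```\n",
        "\n",
        "With a single step and a membrane potential that starts at zero, the potential simply equals the input. Over several steps it accumulates, so one option is to sum or average the membrane potential (or the spikes) across time before computing goodness."
      ]
    },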
    {
      "cell_type": "markdown",
      "id": "43Tf1Hpfkn7m",
      "metadata": {
        "id": "43Tf1Hpfkn7m"
      },
      "source": [
        "## 6. Plotting the Loss of Each Layer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cN7wLeKGgXq3",
      "metadata": {
        "id": "cN7wLeKGgXq3"
      },
      "outputs": [],
      "source": [
        "plt.plot(loss[0,:], label=\"Layer 1 Loss\")\n",
        "plt.plot(loss[1,:], label=\"Layer 2 Loss\")\n",
        "plt.xlabel(\"Epoch\")\n",
        "plt.ylabel(\"Loss\")\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "E0ncALQA5T7N",
      "metadata": {
        "id": "E0ncALQA5T7N"
      },
      "source": [
        "## 7. Train the Network using ReLU instead of LIF Activation\n",
        "\n",
        "*   To compare, we can also train and test the same network architecture on the same data, except now using ReLU as the activation function instead of LIF.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Y40FBDxD5Pf5",
      "metadata": {
        "id": "Y40FBDxD5Pf5"
      },
      "outputs": [],
      "source": [
        "net = Net([784, 500, 500], \"relu\").to(device)\n",
        "\n",
        "# load training data\n",
        "x, y = next(iter(train_loader))\n",
        "x, y = x.to(device), y.to(device)\n",
        "\n",
        "# overlay labels on data\n",
        "x_pos = overlay_y_on_x(x, y)\n",
        "rnd = torch.randperm(x.size(0))\n",
        "x_neg = overlay_y_on_x(x, y[rnd])\n",
        "\n",
        "# visaualize samples\n",
        "for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):\n",
        "    visualize_sample(data, name)\n",
        "\n",
        "net.train(x_pos, x_neg)\n",
        "print('train error:', 100*(1.0 - net.predict(x).eq(y).float().mean().item()),'%')\n",
        "\n",
        "# load test data\n",
        "x_te, y_te = next(iter(test_loader))\n",
        "x_te, y_te = x_te.to(device), y_te.to(device)\n",
        "print('test error:', 100*(1.0 - net.predict(x_te).eq(y_te).float().mean().item()),'%')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Oa2jm154zvGH",
      "metadata": {
        "id": "Oa2jm154zvGH"
      },
      "source": [
        "# Comparison against ANNs and SNNs using Backpropagation\n",
        "As shown in [this Keras example](https://www.tensorflow.org/datasets/keras_example), a standard artificial neural network using backpropagation with similar architecure and hyperparameters can attain a test set error of about 2.50%.\n",
        "\n",
        "As shown in [the snnTorch Tutorial 5 documentation](https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_5.html), a spiking neural network using backpropagation with the same architecture and similar set hyperparameters attains a test set error of about 6.13%.\n",
        "\n",
        "See if you can modify the code to bridge that gap."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0aIXWPOHkiD1",
      "metadata": {
        "id": "0aIXWPOHkiD1"
      },
      "source": [
        "# References\n",
        "> <cite> [Geoffrey Hinton. The Forward-Forward Algorithm: Some Preliminary\n",
        "Investigations. 2022. arXiv: 2212.13345 [cs.LG].](https://arxiv.org/abs/2212.13345) </cite>\n",
        "\n",
        "> <cite> [Mohammad Pezeshki. pytorch_forward_forward. Github Repo January 2023.](https://github.com/mpezeshki/pytorch_forward_forward) </cite>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Acknowledgments\n",
        "This project was supported in-part by the National Science Foundation RCN-SC 2332166."
      ],
      "metadata": {
        "id": "OLfY9dc_hKYi"
      },
      "id": "OLfY9dc_hKYi"
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "snnenv-home",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}